#!/usr/bin/env python3
"""
Phase 4: Demographics Predict Capital Control Direction
Research question G: Do aging countries systematically liberalize?

Models:
  1. d_kaopen = β·Z + controls
  2. d_kaopen = β·old_dep + controls
  3. kaopen_level = β·Z + controls
  4. Linear probability model (LPM): P(liberalize) = f(Z, controls)
  5. Feldstein-Horioka saving-investment regressions with Z and KAOPEN interactions
  6. By income group
  7. Granger-style: Z predicts future ΔKAOPEN (h=1,...,5)

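Outputs (in extensions/output/tables/):
  - kaopen_prediction.md           formatted markdown report
  - kaopen_prediction_results.csv  full coefficient table for every model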
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys

# ── Paths ──────────────────────────────────────────────────────────────────
# PROJECT_DIR is two directory levels above this file; presumably the
# "extensions" project root, given the output path in the module docstring
# — TODO confirm.
PROJECT_DIR = Path(__file__).resolve().parent.parent
# ROOT_DIR is one level above PROJECT_DIR; the sibling "multilateral"
# project is expected to live directly under it.
ROOT_DIR = PROJECT_DIR.parent
MULTILATERAL_DIR = ROOT_DIR / "multilateral"

# Created at import time so that later writes never fail on a missing
# directory (exist_ok makes re-runs harmless).
OUT_TABLES = PROJECT_DIR / "output" / "tables"
OUT_TABLES.mkdir(parents=True, exist_ok=True)

# Prepend multilateral/src to the import path; this must run before the
# import below, which resolves `model` from that directory ahead of any
# other `model` module on sys.path.
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS


def load_panel(path=None, max_year=2024):
    """Load the processed extensions panel, keeping years up to ``max_year``.

    Parameters
    ----------
    path : path-like or file-like, optional
        Anything accepted by ``pandas.read_csv``. Defaults to
        ``PROJECT_DIR / "data" / "processed" / "extensions_panel.csv"``.
    max_year : int, default 2024
        Last year kept (inclusive); rows with a later ``year`` are dropped.

    Returns
    -------
    pandas.DataFrame
        An independent copy of the filtered panel.
    """
    if path is None:
        path = PROJECT_DIR / "data" / "processed" / "extensions_panel.csv"
    df = pd.read_csv(path)
    # Rows with a NaN year also drop out here: the comparison is False.
    return df[df["year"] <= max_year].copy()


def stars(p):
    """Return conventional significance stars for a p-value.

    ``***`` when p < 0.01, ``**`` when p < 0.05, ``*`` when p < 0.1,
    otherwise an empty string (NaN also yields "" because every
    comparison against it is False).
    """
    thresholds = ((0.01, "***"), (0.05, "**"), (0.1, "*"))
    for cutoff, marker in thresholds:
        if p < cutoff:
            return marker
    return ""


def run_model(df, dep_var, regressors, label, min_obs=100):
    """Fit a PanelGLS regression of ``dep_var`` on ``regressors``.

    Rows missing the dependent variable, any regressor, or the panel
    identifiers (``iso3``, ``year``) are dropped before fitting.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``iso3`` and ``year`` identifier columns.
    dep_var : str
        Dependent-variable column.
    regressors : list of str
        Regressor columns, in the order reported in the output.
    label : str
        Model name; printed and stored in the ``model`` column.
    min_obs : int, default 100
        Minimum number of complete rows required to fit; fewer → skip.

    Returns
    -------
    pandas.DataFrame or None
        ``PanelGLS.to_dataframe`` output with model metadata columns
        appended, or None when the model is skipped (missing columns or
        too few observations).
    """
    regressors = list(regressors)
    needed = [dep_var] + regressors + ["iso3", "year"]

    # Several inputs are optional in the panel; skip the model rather than
    # raising a KeyError that would abort every later specification.
    missing = [c for c in needed if c not in df.columns]
    if missing:
        print(f"  SKIP {label}: missing columns {missing}")
        return None

    sub = df.dropna(subset=needed).copy()
    if len(sub) < min_obs:
        print(f"  SKIP {label}: only {len(sub)} obs")
        return None

    model = PanelGLS()
    model.fit(sub[dep_var].values, sub[regressors].values,
              sub["iso3"].values, sub["year"].values)
    print(f"\n  {label}")
    model.summary(feature_names=regressors)

    res = model.to_dataframe(feature_names=regressors)
    res["model"] = label
    res["dep_var"] = dep_var
    res["n_obs"] = model.n_obs
    res["n_countries"] = model.n_countries
    res["r_squared"] = model.r_squared
    res["rho"] = model.rho
    return res


def main():
    """Estimate every Phase 4 specification and save the combined results.

    Loads the panel, fits Models 1-7 (see module docstring) through
    ``run_model``, prints a summary of the Z_1 coefficient across lead
    horizons, and writes ``kaopen_prediction_results.csv`` plus the markdown
    report to ``OUT_TABLES``. A specification whose variables are absent
    from the panel is skipped with a message instead of aborting the run.
    """
    print("=" * 70)
    print("PHASE 4: Demographics Predict Capital Control Direction")
    print("=" * 70)

    # Sort by country and year so within-country shifts (Model 7 leads)
    # follow calendar order regardless of the CSV's row order.
    df = load_panel().sort_values(["iso3", "year"]).reset_index(drop=True)

    demo_vars = ["Z_1", "Z_2", "Z_3"]
    # Controls are used only if present with enough non-missing values.
    base_controls = [c for c in ["nfa_gdp_lag", "log_rel_opw"]
                     if c in df.columns and df[c].notna().sum() > 200]
    z_controls = demo_vars + base_controls

    all_results = []

    def has(*cols):
        """Return True when every named column exists in the panel."""
        return all(c in df.columns for c in cols)

    def fit(dep_var, regressors, label, data=None):
        """Run one specification; record and return its results (or None)."""
        data = df if data is None else data
        missing = [c for c in [dep_var, *regressors] if c not in data.columns]
        if missing:
            print(f"  SKIP {label}: missing columns {missing}")
            return None
        r = run_model(data, dep_var, regressors, label)
        if r is not None:
            all_results.append(r)
        return r

    def add_product(name, *factors):
        """Add column ``name`` = product of ``factors`` (NaN-propagating)."""
        if not has(*factors):
            return
        product = df[factors[0]]
        for factor in factors[1:]:
            product = product * df[factor]
        df[name] = product

    # ── Model 1: d_kaopen = β·Z + controls ────────────────────────────────
    fit("d_kaopen", z_controls, "M1: Z → ΔKAOPEN")

    # ── Model 2: d_kaopen = β·old_dep + controls ──────────────────────────
    fit("d_kaopen", ["old_dep"] + base_controls, "M2: old_dep → ΔKAOPEN")

    # ── Model 3: kaopen level = β·Z + controls ────────────────────────────
    fit("kaopen", z_controls, "M3: Z → KAOPEN level")

    # ── Model 4: linear probability model P(liberalize) = f(Z, controls) ─
    fit("kaopen_liberalized", z_controls, "M4: LPM Z → P(liberalize)")

    # ── Model 5: Feldstein-Horioka with Z interactions ─────────────────────
    print("\n  --- Feldstein-Horioka Tests ---")

    # 5a: Basic FH
    fit("investment_gdp", ["savings_gdp"], "M5a: FH baseline (I = α + β·S)")

    # 5b: FH with Z
    fit("investment_gdp", ["savings_gdp"] + demo_vars, "M5b: FH + Z")

    # 5c: FH with S×Z interactions (does aging weaken S-I correlation?)
    add_product("savings_x_Z1", "savings_gdp", "Z_1")
    add_product("savings_x_Z2", "savings_gdp", "Z_2")
    add_product("savings_x_Z3", "savings_gdp", "Z_3")
    fit("investment_gdp",
        ["savings_gdp"] + demo_vars
        + ["savings_x_Z1", "savings_x_Z2", "savings_x_Z3"],
        "M5c: FH + S×Z interactions")

    # 5d: FH with S×kaopen (standard result: openness weakens FH)
    add_product("savings_x_kaopen", "savings_gdp", "kaopen")
    fit("investment_gdp", ["savings_gdp", "kaopen", "savings_x_kaopen"],
        "M5d: FH + S×KAOPEN")

    # 5e: Triple — S × Z × KAOPEN
    add_product("savings_x_Z1_x_kaopen", "savings_gdp", "Z_1", "kaopen")
    fit("investment_gdp",
        ["savings_gdp", "kaopen"] + demo_vars
        + ["savings_x_kaopen", "savings_x_Z1", "savings_x_Z1_x_kaopen"],
        "M5e: FH triple (S×Z×KAOPEN)")

    # ── Model 6: By income group ──────────────────────────────────────────
    # Split at each year's cross-country median GDP per capita (PPP);
    # rows with missing GDP per capita stay unclassified (NaN).
    if has("gdp_pc_ppp"):
        gdp = df["gdp_pc_ppp"]
        year_median = df.groupby("year")["gdp_pc_ppp"].transform("median")
        df["high_income"] = (gdp > year_median).astype(float).where(gdp.notna())

        for hi, label in [(1, "high-income"), (0, "low-income")]:
            fit("d_kaopen", z_controls, f"M6: Z → ΔKAOPEN ({label})",
                data=df[df["high_income"] == hi].copy())

    # ── Model 7: Granger-style — Z predicts future ΔKAOPEN ────────────────
    print("\n  --- Granger-style: Z predicts future ΔKAOPEN ---")
    granger_results = []
    if has("d_kaopen"):
        for h in range(1, 6):
            by_country = df.groupby("iso3")
            lead = by_country["d_kaopen"].shift(-h)
            lead_year = by_country["year"].shift(-h)
            # Keep the lead only when it is exactly h calendar years ahead;
            # a bare row shift would pair non-adjacent years across gaps.
            df[f"d_kaopen_f{h}"] = lead.where(lead_year == df["year"] + h)
            r = fit(f"d_kaopen_f{h}", z_controls, f"M7: Z → ΔKAOPEN(t+{h})")
            if r is None:
                continue
            # Extract Z_1 coefficient for Granger summary
            z1_row = r[r["variable"] == "Z_1"]
            if len(z1_row) > 0:
                granger_results.append({
                    "horizon": h,
                    "Z1_coef": z1_row.iloc[0]["coefficient"],
                    "Z1_pval": z1_row.iloc[0]["p_value"],
                    "r_squared": r.iloc[0]["r_squared"],
                })

    if granger_results:
        gdf = pd.DataFrame(granger_results)
        print("\n  Granger Summary (Z₁ → future ΔKAOPEN):")
        print(gdf.to_string(index=False, float_format="%.4f"))

    # ── Combine and save ───────────────────────────────────────────────────
    if all_results:
        results_df = pd.concat(all_results, ignore_index=True)
        results_df.to_csv(OUT_TABLES / "kaopen_prediction_results.csv",
                          index=False)
        write_markdown_table(results_df, granger_results)
        print("\n  Done. Results saved to extensions/output/tables/")
    else:
        print("\n  Done. No models were estimated; nothing saved.")


def write_markdown_table(results_df, granger_results, md_path=None):
    """Write the formatted markdown report of all estimated models.

    Parameters
    ----------
    results_df : pandas.DataFrame
        Concatenated ``run_model`` outputs. Columns read: ``model``,
        ``dep_var``, ``n_obs``, ``n_countries``, ``r_squared``, ``rho``,
        ``variable``, ``coefficient``, ``std_error``, ``p_value``.
    granger_results : list of dict
        Per-horizon entries with keys ``horizon``, ``Z1_coef``, ``Z1_pval``
        and ``r_squared``; the Granger section is omitted when empty.
    md_path : path-like, optional
        Destination file. Defaults to ``OUT_TABLES / "kaopen_prediction.md"``.
    """
    lines = ["# Demographics Predict Capital Account Liberalization", ""]

    models = results_df["model"].unique()
    lines.append("## Model Comparison")
    lines.append("")
    lines.append("| Model | Dep Var | N | Countries | R² | ρ |")
    lines.append("|-------|---------|---|-----------|----|----|")
    for m in models:
        # Model-level statistics are repeated on every row; take the first.
        sub = results_df[results_df["model"] == m].iloc[0]
        lines.append(f"| {m} | {sub['dep_var']} | {sub['n_obs']:,} | "
                     f"{sub['n_countries']} | {sub['r_squared']:.3f} | "
                     f"{sub['rho']:.3f} |")

    # Key coefficients
    lines.append("")
    lines.append("## Key Coefficients")
    lines.append("")
    key_vars = ["Z_1", "Z_2", "Z_3", "old_dep", "savings_gdp",
                "savings_x_kaopen", "savings_x_Z1",
                "savings_x_Z1_x_kaopen"]
    lines.append("| Variable | Model | Coef | SE | p-value |")
    lines.append("|----------|-------|------|----|---------|")
    for v in key_vars:
        sub = results_df[results_df["variable"] == v]
        for _, row in sub.iterrows():
            if pd.isna(row["p_value"]):
                continue
            sig = stars(row["p_value"])
            lines.append(f"| {v} | {row['model']} | "
                         f"{row['coefficient']:.4f}{sig} | "
                         f"{row['std_error']:.4f} | {row['p_value']:.4f} |")

    # Granger summary
    if granger_results:
        lines.append("")
        lines.append("## Granger Predictability: Z₁ → Future ΔKAOPEN")
        lines.append("")
        lines.append("| Horizon | Z₁ Coef | p-value | R² |")
        lines.append("|---------|---------|---------|-----|")
        for g in granger_results:
            sig = stars(g["Z1_pval"])
            lines.append(f"| t+{g['horizon']} | {g['Z1_coef']:.4f}{sig} | "
                         f"{g['Z1_pval']:.4f} | {g['r_squared']:.3f} |")

    # Feldstein-Horioka summary
    lines.append("")
    lines.append("## Feldstein-Horioka Results")
    lines.append("")
    lines.append("The savings retention coefficient (β in I = α + β·S) "
                 "measures capital mobility.")
    lines.append("β = 1 implies no capital mobility; β → 0 implies "
                 "perfect mobility.")
    lines.append("")
    fh_models = [m for m in models if "FH" in m]
    lines.append("| Model | Savings Coef | SE | p-value | R² |")
    lines.append("|-------|-------------|-----|---------|-----|")
    for m in fh_models:
        sub = results_df[(results_df["model"] == m)
                         & (results_df["variable"] == "savings_gdp")]
        if len(sub) == 0:
            continue
        row = sub.iloc[0]
        sig = stars(row["p_value"]) if pd.notna(row["p_value"]) else ""
        lines.append(f"| {m} | {row['coefficient']:.4f}{sig} | "
                     f"{row['std_error']:.4f} | {row['p_value']:.4f} | "
                     f"{row['r_squared']:.3f} |")

    if md_path is None:
        md_path = OUT_TABLES / "kaopen_prediction.md"
    md_path = Path(md_path)
    # The report contains non-ASCII symbols (², ρ, →, Δ, Z₁, ×); write UTF-8
    # explicitly so the platform default encoding cannot raise
    # UnicodeEncodeError. End with a newline, as text files conventionally do.
    md_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    print(f"\n  Markdown table: {md_path}")


if __name__ == "__main__":
    # Run the full estimation pipeline only when executed as a script, not
    # when this module is imported.
    main()
